import pymc as pm
import numpy as np
import pandas as pd
import arviz as az
import pytensor.tensor as pt
from scipy.interpolate import BSpline
If you look to Odysseus on the morning the gates of Troy fell, he is well set up for a happy journey home. He is the architect of victory, his ships are loaded with spoils, and the wind is at his back. Yet, an odyssey can’t be completed in a single day and conclusions drawn on the outset rarely survive journey’s end.
When we rely on static snapshots, like a single blood draw or a particular sales campaign, we risk mistaking a single moment for the whole trajectory of the process.
# Load the simulated Aalen start-stop survival data, keeping only the
# columns used downstream.
cols = ['subject', 'x', 'dose', 'M', 'start', 'stop', 'event']
df = pd.read_csv("aalen_simdata.csv")[cols]
df.head()

| subject | x | dose | M | start | stop | event | |
|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | ctrl | 6.74 | 0 | 4.00 | 0 |
| 1 | 1 | 0 | ctrl | 6.91 | 4 | 8.00 | 0 |
| 2 | 1 | 0 | ctrl | 6.90 | 8 | 12.00 | 0 |
| 3 | 1 | 0 | ctrl | 6.71 | 12 | 26.00 | 0 |
| 4 | 1 | 0 | ctrl | 6.45 | 26 | 46.85 | 1 |
df.groupby(['x', 'dose'])[['event', 'M']].agg(['mean', 'sum'])

| event | M | ||||
|---|---|---|---|---|---|
| mean | sum | mean | sum | ||
| x | dose | ||||
| 0 | ctrl | 0.164179 | 66 | 6.996915 | 2812.76 |
| 1 | high | 0.119205 | 54 | 8.081589 | 3660.96 |
| low | 0.139037 | 52 | 7.302620 | 2731.18 | |
Code
import matplotlib.pyplot as plt
import pandas as pd

# Order subjects by treatment arm, then by length of follow-up, so the
# timeline plot groups the two arms together.
subject_info = (
    df.groupby('subject')
      .agg(x=('x', 'first'), max_stop=('stop', 'max'))
      .sort_values(['x', 'max_stop'])
)
subjects = subject_info.index.tolist()
subject_to_y = dict(zip(subjects, range(len(subjects))))

fig, ax = plt.subplots(figsize=(8, 0.1 * len(subjects)))
for _, interval in df.iterrows():
    y_pos = subject_to_y[interval['subject']]
    seg_color = 'tab:blue' if interval['x'] == 1 else 'tab:orange'
    # One horizontal segment per at-risk interval.
    ax.hlines(
        y=y_pos,
        xmin=interval['start'],
        xmax=interval['stop'],
        color=seg_color,
        linewidth=3
    )
    if interval['event'] == 1:
        # Mark the event time with a red dot on top of the segment.
        ax.plot(
            interval['stop'],
            y_pos,
            marker='o',
            color='red',
            markersize=6,
            zorder=3
        )

# Axis formatting
ax.set_yticks(range(len(subjects)))
ax.set_yticklabels(subjects)
ax.set_xlabel("Time")
ax.set_ylabel("Subject")

# Dashed line visually separating the treatment groups.
n_control = (subject_info['x'] == 0).sum()
ax.axhline(n_control - 0.5, color='black', linestyle='--', linewidth=1)

# Legend
from matplotlib.lines import Line2D
legend_elements = [
    Line2D([0], [0], color='tab:blue', lw=3, label='x = 1'),
    Line2D([0], [0], color='tab:orange', lw=3, label='x = 0'),
    Line2D([0], [0], marker='o', color='red', lw=0, label='Event', markersize=6)
]
ax.legend(handles=legend_elements, loc='upper right')
ax.set_title("Subject Timelines Ordered by Treatment Level")
plt.tight_layout()
plt.show()

Data Preparation
def prepare_aalen_dpa_data(
    df,
    subject_col="subject",
    start_col="start",
    stop_col="stop",
    event_col="event",
    x_col="x",
    m_col="M",
    dose_col="dose",
):
    """
    Prepare Andersen–Gill / Aalen dynamic path data for PyMC.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format start–stop survival data
    subject_col : str
        Subject identifier
    start_col, stop_col : str
        Interval boundaries
    event_col : str
        Event indicator (0/1)
    x_col : str
        Exposure / treatment
    m_col : str
        Mediator measured at interval start
    dose_col : str
        Dose category column; rows equal to "low" / "high" get the
        corresponding indicator, anything else is treated as control

    Returns
    -------
    dict
        Numpy arrays ready for PyMC (all aligned with the row order of
        the returned ``df_long``), plus the bin table and the long
        dataframe itself.

    Raises
    ------
    ValueError
        If any interval has non-positive length.
    """
    df = df.copy()

    # -------------------------------------------------
    # 1. Basic quantities
    # -------------------------------------------------
    df["dt"] = df[stop_col] - df[start_col]
    if (df["dt"] <= 0).any():
        raise ValueError("Non-positive interval lengths detected.")

    # -------------------------------------------------
    # 2. Time-bin indexing (piecewise-constant effects)
    # -------------------------------------------------
    bins = (
        df[[start_col, stop_col]]
        .drop_duplicates()
        .sort_values([start_col, stop_col])
        .reset_index(drop=True)
    )
    bins["bin_idx"] = np.arange(len(bins))
    df = df.merge(
        bins,
        on=[start_col, stop_col],
        how="left",
        validate="many_to_one"
    )
    n_bins = bins.shape[0]

    # -------------------------------------------------
    # 3. Center covariates (important for Aalen models)
    # -------------------------------------------------
    df["x_c"] = df[x_col]
    df["m_c"] = df[m_col] - df[m_col].mean()

    # -------------------------------------------------
    # 4. Predictable mediator (lag within subject)
    # -------------------------------------------------
    # Sort BEFORE extracting any arrays so every exported vector is
    # aligned with the (subject, start) row order of df_long.
    df = df.sort_values([subject_col, start_col])
    df["m_lag"] = (
        df.groupby(subject_col)["m_c"]
        .shift(1)
        .fillna(0.0)   # no mediator history before the first interval
    )
    df["I_low"] = (df[dose_col] == "low").astype(int)
    df["I_high"] = (df[dose_col] == "high").astype(int)

    # -------------------------------------------------
    # 5. Assemble output (previously N, Y, bin_idx, x, m, m_lag were
    #    computed but silently dropped from the returned dict)
    # -------------------------------------------------
    data = {
        "N": df[event_col].astype(int).values,   # event count per interval
        "Y": np.ones(len(df), dtype=int),        # Andersen–Gill at-risk indicator
        "dt": df["dt"].values,                   # interval lengths
        "bin_idx": df["bin_idx"].values,         # time-bin index per row
        "n_bins": n_bins,                        # number of distinct bins
        "x": df["x_c"].values,                   # treatment
        "m": df["m_c"].values,                   # centered mediator
        "m_lag": df["m_lag"].values,             # lagged centered mediator
        "bins": bins,                            # useful for plotting
        "df_long": df                            # optional: debugging / inspection
    }
    return data


data = prepare_aalen_dpa_data(df)
# Pull out the merged long-format frame for inspection below.
df_long = data['df_long']
df_long[['subject', 'x', 'dose', 'M', 'event', 'dt', 'bin_idx']].head(14)

| subject | x | dose | M | event | dt | bin_idx | |
|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | ctrl | 6.74 | 0 | 4.00 | 7 |
| 1 | 1 | 0 | ctrl | 6.91 | 0 | 4.00 | 13 |
| 2 | 1 | 0 | ctrl | 6.90 | 0 | 4.00 | 23 |
| 3 | 1 | 0 | ctrl | 6.71 | 0 | 14.00 | 53 |
| 4 | 1 | 0 | ctrl | 6.45 | 1 | 20.85 | 81 |
| 5 | 2 | 1 | high | 6.11 | 0 | 4.00 | 7 |
| 6 | 2 | 1 | high | 6.28 | 0 | 4.00 | 13 |
| 7 | 2 | 1 | high | 7.04 | 0 | 4.00 | 23 |
| 8 | 2 | 1 | high | 6.93 | 0 | 14.00 | 53 |
| 9 | 2 | 1 | high | 7.86 | 0 | 26.00 | 89 |
| 10 | 2 | 1 | high | 8.47 | 0 | 26.00 | 115 |
| 11 | 2 | 1 | high | 8.91 | 0 | 26.00 | 137 |
| 12 | 2 | 1 | high | 8.99 | 0 | 52.00 | 162 |
| 13 | 2 | 1 | high | 9.36 | 0 | 104.00 | 188 |
def create_bspline_basis(n_bins, n_knots=10, degree=3):
    """
    Build a B-spline design matrix for smooth time-varying effects.

    Parameters
    ----------
    n_bins : int
        Number of time bins; basis functions are evaluated at
        0, 1, ..., n_bins - 1.
    n_knots : int
        Number of equally spaced breakpoints spanning the time range
        (fewer = smoother).
    degree : int
        Degree of the spline (3 = cubic, recommended).

    Returns
    -------
    np.ndarray
        Matrix of shape (n_bins, n_knots + degree - 1) whose columns are
        the basis functions evaluated at each time point.
    """
    # Equally spaced breakpoints over the time range; the boundary knots
    # are repeated `degree` extra times to clamp the spline at the ends.
    breakpoints = np.linspace(0, n_bins - 1, n_knots)
    knots = np.concatenate([
        np.full(degree, breakpoints[0]),
        breakpoints,
        np.full(degree, breakpoints[-1]),
    ])
    n_basis = knots.size - degree - 1

    # Each row of the identity matrix selects exactly one basis function,
    # so evaluating BSpline with that coefficient vector yields its column.
    grid = np.arange(n_bins, dtype=float)
    columns = [
        BSpline(knots, indicator, degree, extrapolate=False)(grid)
        for indicator in np.eye(n_basis)
    ]
    return np.column_stack(columns)
# Build the spline design matrix over the time bins found in the data.
n_knots = 10
n_bins = data['bins'].shape[0]  # number of distinct (start, stop) intervals
basis = create_bspline_basis(n_bins, n_knots=n_knots, degree=3)
n_cols = basis.shape[1]  # n_knots + degree - 1 basis functions
# Wrapped in a DataFrame purely for display below.
basis_df = pd.DataFrame(basis, columns=[f'feature_{i}' for i in range(n_cols)])
basis_df.head(10)

| feature_0 | feature_1 | feature_2 | feature_3 | feature_4 | feature_5 | feature_6 | feature_7 | feature_8 | feature_9 | feature_10 | feature_11 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 1 | 0.863149 | 0.133496 | 0.003337 | 0.000018 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 2 | 0.739389 | 0.247518 | 0.012946 | 0.000146 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 3 | 0.628064 | 0.343219 | 0.028223 | 0.000494 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 4 | 0.528515 | 0.421749 | 0.048566 | 0.001170 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 5 | 0.440083 | 0.484261 | 0.073370 | 0.002286 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 6 | 0.362110 | 0.531908 | 0.102032 | 0.003950 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 7 | 0.293939 | 0.565840 | 0.133949 | 0.006272 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 8 | 0.234909 | 0.587211 | 0.168518 | 0.009362 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 9 | 0.184365 | 0.597171 | 0.205134 | 0.013330 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
def make_model(data, basis, sample=True, observed=True):
    """
    Build (and optionally fit) the Aalen dynamic-path mediation model.

    Parameters
    ----------
    data : dict
        Output of `prepare_aalen_dpa_data`; uses the 'df_long' and
        'bins' entries.
    basis : np.ndarray
        B-spline design matrix of shape (n_time_bins, n_basis).
    sample : bool
        If True, draw prior predictive, run NUTS, and draw posterior
        predictive samples; `idata` is only defined in this case.
    observed : bool
        If True, condition the Poisson event likelihood on the data;
        otherwise leave it unobserved (e.g. for simulation).

    Returns
    -------
    (pm.Model, az.InferenceData)
    """
    df_long = data['df_long'].copy()
    n_basis = basis.shape[1]
    n_obs = data['df_long'].shape[0]
    time_bins = data['bins']['bin_idx'].values
    # Integer bin index per row; used to look up the spline value that
    # applies to each observation interval.
    b = df_long['bin_idx'].values
    observed_mediator = df_long["m_c"].values
    observed_events = df_long['event'].astype(int).values
    observed_treatment = df_long['x'].astype(int).values
    observed_mediator_lag = df_long['m_lag'].values
    coords = {'tv': ['intercept', 'direct', 'mediator'],
              # BUG FIX: original used 'spline_f_{i}' without the f-prefix,
              # giving every spline coordinate the same literal label.
              'splines': [f'spline_f_{i}' for i in range(n_basis)],
              'obs': range(n_obs),
              'time_bins': time_bins}
    with pm.Model(coords=coords) as aalen_dpa_model:
        trt = pm.Data("trt", observed_treatment, dims="obs")
        med = pm.Data("mediator", observed_mediator, dims="obs")
        med_lag = pm.Data("mediator_lag", observed_mediator_lag, dims="obs")
        events = pm.Data("events", observed_events, dims="obs")
        I_low = pm.Data("I_low", df_long["I_low"].values, dims="obs")
        I_high = pm.Data("I_high", df_long["I_high"].values, dims="obs")
        dt = pm.Data("duration", df_long['dt'].values, dims='obs')
        ## because our long data format has a cell per obs
        at_risk = pm.Data("at_risk", np.ones(len(observed_events)), dims="obs")
        basis_ = pm.Data("basis", basis, dims=('time_bins', 'splines'))
        # -------------------------------------------------
        # 1. B-spline coefficients for HAZARD model
        # -------------------------------------------------
        # Prior on spline coefficients
        # Smaller sigma = less wiggliness
        # Random Walk 1 (RW1) Prior for coefficients
        # This is the Bayesian version of the smoothing penalty in R's 'mgcv' or 'timereg'
        sigma_smooth = pm.Exponential("sigma_smooth", [1, 1, 1], dims='tv')
        beta_raw = pm.Normal("beta_raw", 0, 1, dims=('splines', 'tv'))
        # Cumulative sum makes it a Random Walk
        # This ensures coefficients evolve smoothly over time
        coef_alpha = pm.Deterministic("coef_alpha", pt.cumsum(beta_raw * sigma_smooth, axis=0), dims=('splines', 'tv'))
        # Construct smooth time-varying functions
        alpha_0_t = pt.dot(basis_, coef_alpha[:, 0])
        alpha_1_t = pt.dot(basis_, coef_alpha[:, 1])
        alpha_2_t = pt.dot(basis_, coef_alpha[:, 2])
        # -------------------------------------------------
        # 2. B-spline coefficients for MEDIATOR model
        # -------------------------------------------------
        sigma_beta_smooth = pm.Exponential("sigma_beta_smooth", 0.1)
        # Renamed from `beta_raw` to avoid shadowing the hazard-model
        # coefficients above (the PyMC name was already distinct).
        beta_raw_m = pm.Normal("beta_raw_m", 0, 1, dims=('splines'))
        coef_beta = pt.cumsum(beta_raw_m * sigma_beta_smooth)
        beta_t = pt.dot(basis_, coef_beta)
        # -------------------------------------------------
        # 3. Mediator model (A path: x → M)
        # -------------------------------------------------
        sigma_m = pm.HalfNormal("sigma_m", 1.0)
        # Autoregressive component
        rho = pm.Beta("rho", 2, 2)
        mu_m = beta_t[b] * trt + rho * med_lag
        pm.Normal(
            "obs_m",
            mu=mu_m,
            sigma=sigma_m,
            observed=med,
            dims='obs'
        )
        # -------------------------------------------------
        # 4. Hazard model (direct + B path)
        # -------------------------------------------------
        beta_low = pm.Normal("beta_low", 0, 0.1)
        beta_high = pm.Normal("beta_high", 0, 0.1)
        # Log-additive hazard
        log_lambda_t = (alpha_0_t[b]
                        + alpha_1_t[b] * trt    # direct effect
                        + alpha_2_t[b] * med    # mediator effect
                        + beta_low * I_low
                        + beta_high * I_high
                        )
        # Expected number of events: softplus maps the additive predictor
        # to a positive hazard, scaled by exposure time in the interval.
        time_at_risk = at_risk * dt
        Lambda = time_at_risk * pm.math.log1pexp(log_lambda_t)
        if observed:
            pm.Poisson(
                "obs_event",
                mu=Lambda,
                observed=events,
                dims='obs'
            )
        else:
            pm.Poisson(
                "obs_event",
                mu=Lambda,
                dims='obs'
            )
        # -------------------------------------------------
        # 5. Causal path effects
        # -------------------------------------------------
        # Store time-varying coefficients
        pm.Deterministic("alpha_0_t", alpha_0_t, dims='time_bins')
        pm.Deterministic("alpha_1_t", alpha_1_t, dims='time_bins')  # direct effect
        pm.Deterministic("alpha_2_t", alpha_2_t, dims='time_bins')  # B path
        pm.Deterministic("beta_t", beta_t, dims='time_bins')        # A path
        # Cumulative direct effect
        cum_de = pm.Deterministic(
            "tv_direct_effect",
            alpha_1_t,
            dims='time_bins'
        )
        # Cumulative indirect effect (product of paths)
        cum_ie = pm.Deterministic(
            "tv_indirect_effect",
            beta_t * alpha_2_t,
            dims='time_bins'
        )
        # Total effect
        cum_te = pm.Deterministic(
            "tv_total_effect",
            cum_de + cum_ie,
            dims='time_bins'
        )
        pm.Deterministic('tv_baseline_hazard', pm.math.log1pexp(alpha_0_t),
                         dims='time_bins')
        pm.Deterministic('tv_hazard_with_exposure', pm.math.log1pexp(alpha_0_t + alpha_1_t),
                         dims='time_bins')
        pm.Deterministic(
            "tv_RR",
            pm.math.log1pexp(alpha_0_t + alpha_1_t) /
            pm.math.log1pexp(alpha_0_t),
            dims="time_bins"
        )
        # -------------------------------------------------
        # 6. Sample
        # -------------------------------------------------
        if sample:
            idata = pm.sample_prior_predictive()
            idata.extend(pm.sample(
                draws=2000,
                tune=2000,
                target_accept=0.95,
                chains=4,
                nuts_sampler="numpyro",
                random_seed=42,
                init="adapt_diag",
                idata_kwargs={"log_likelihood": True}
            ))
            idata.extend(pm.sample_posterior_predictive(idata))
    return aalen_dpa_model, idata
# Refit with a denser basis for the headline model, then scan over knot
# counts to compare smoothness choices. (These statements were fused onto
# single lines by the export; logic unchanged.)
basis = create_bspline_basis(n_bins, n_knots=12, degree=3)
aalen_dpa_model, idata_aalen = make_model(data, basis)
pm.model_to_graphviz(aalen_dpa_model)

models = {}
idatas = {}
for i in range(4, 15, 2):
    basis = create_bspline_basis(n_bins, n_knots=i, degree=3)
    aalen_dpa_model, idata = make_model(data, basis)
    models[i] = aalen_dpa_model
    idatas[f"splines_{i}"] = idata
# Model comparison via the pointwise log-likelihood of the event model.
compare_df = az.compare(idatas, var_name='obs_event')
az.plot_compare(compare_df, figsize=(8, 6), plot_ic_diff=True)

ax = az.plot_forest([idatas[k] for k in idatas.keys()], combined=True, var_names=['tv_direct_effect'], model_names=idatas.keys(), coords={'time_bins': [180, 182, 182, 183, 184, 185, 186, 187, 188]},
                    figsize=(12, 10), r_hat=True)
ax[0].set_title("Time Vary Direct Effects \n Comparing Models on Final Time Intervals", fontsize=15)
ax[0].set_ylabel("Nth Time Interval", fontsize=15)
fig = ax[0].figure
fig.savefig('forest_plot_comparing_tv_direct.png')

az.plot_trace(idata_aalen, var_names=['tv_direct_effect', 'tv_indirect_effect', 'tv_total_effect', 'beta_high', 'beta_low'], divergences=False);
plt.tight_layout()

vars_to_plot = ['tv_direct_effect', 'tv_indirect_effect', 'tv_total_effect']
labels = ['Time varying Direct Effect', 'Time varying Indirect Effect', 'Time varying Total Effect']
def plot_effects(idata, vars_to_plot, labels, scale="Log Hazard Ratio Scale"):
    """
    Plot posterior means with 94% HDI bands for time-varying effects.

    Parameters
    ----------
    idata : arviz.InferenceData
        Posterior containing the requested time-indexed variables.
    vars_to_plot : list of str
        Variable names to plot, one panel each (expects three).
    labels : list of str
        Panel titles / legend labels, parallel to `vars_to_plot`.
    scale : str
        Y-axis label; any value other than the default switches the line
        colour from teal to dark red.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, axs = plt.subplots(1, 3, figsize=(20, 10))
    color = 'teal'
    if scale != "Log Hazard Ratio Scale":
        color = 'darkred'
    for i, var in enumerate(vars_to_plot):
        # 1. Extract the posterior samples for this variable
        # Shape will be (chain * draw, time)
        post_samples = az.extract(idata, var_names=[var]).values.T
        # 2. Calculate the mean and the 94% HDI across the chains/draws.
        # FIX: add a leading singleton "chain" axis so az.hdi receives
        # (chain, draw, time) — same values, but avoids the FutureWarning
        # about ambiguous 2-D input interpretation.
        mean_val = post_samples.mean(axis=0)
        hdi_val = az.hdi(post_samples[np.newaxis, ...], hdi_prob=0.94)  # [time, 2] array
        # 3. Plot the Mean line
        x_axis = np.arange(len(mean_val))
        axs[i].plot(x_axis, mean_val, label=labels[i], color=color, lw=2)
        # 4. Plot the Shaded HDI region
        axs[i].fill_between(x_axis, hdi_val[:, 0], hdi_val[:, 1], color=color, alpha=0.2, label='94% HDI')
        # Formatting
        axs[i].set_title(labels[i])
        axs[i].legend()
        axs[i].grid(alpha=0.3)
        axs[i].set_ylabel(scale)
    plt.tight_layout()
    return fig
plot_effects(idata_aalen, vars_to_plot, labels);

/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_13629/3910338117.py:17: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_13629/3910338117.py:17: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_13629/3910338117.py:17: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
# Repeat the effect plots on the hazard scale.
vars_to_plot = ['tv_baseline_hazard', 'tv_hazard_with_exposure', 'tv_RR']
labels = ['Time varying Baseline Hazard', 'Time varying Hazard + Exposure', 'Time varying RR']
plot_effects(idata_aalen, vars_to_plot, labels, scale='Hazard Scale');

/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_13629/3910338117.py:17: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_13629/3910338117.py:17: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_13629/3910338117.py:17: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
Citation
BibTeX citation:
@online{forde,
author = {Forde, Nathaniel},
title = {Aalen’s {Dynamic} {Path} {Model}},
langid = {en}
}
For attribution, please cite this work as:
Forde, Nathaniel. n.d. “Aalen’s Dynamic Path Model.”